from dis import Instruction
import os
from site import check_enableusersite
import sys
from datasets import load_dataset
import pickle
import seaborn as sns
import pandas as pd
import fire
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
import json
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import numpy as np
import random
import warnings
import matplotlib.pyplot as plt
import matplotlib

# Suppress every Python warning for the whole process.
# NOTE(review): this also hides potentially useful diagnostics from the
# libraries used below; narrow the filter (category/module) when debugging.
warnings.filterwarnings('ignore')

def get_output(
    model,
    instruction,
    tokenizer,
    input=None,
    temperature=0.5,
    top_p=0.2,
    top_k=40,
    num_beams=4,
    max_new_tokens=1,
    device='cuda'):
    """Run ``model.generate`` on a prompt and return the raw generation output.

    The prompt is ``instruction`` followed directly by ``input`` (no separator)
    when ``input`` is truthy, otherwise just ``instruction``. Generation is
    requested with ``output_hidden_states=True`` and
    ``return_dict_in_generate=True`` so callers can read
    ``output['hidden_states']``.

    Args:
        model: object exposing ``generate`` (e.g. a Hugging Face causal LM).
        instruction: prompt text.
        tokenizer: callable tokenizer that returns PyTorch tensors.
        input: optional text appended to ``instruction``.
        temperature, top_p, top_k: forwarded to ``GenerationConfig``.
            NOTE(review): ``do_sample`` is not set here, so with the
            transformers default these presumably do not influence decoding —
            confirm against the installed transformers version.
        num_beams: currently unused; kept for interface compatibility.
        max_new_tokens: number of new tokens to generate.
        device: device the prompt tensors are moved to.

    Returns:
        Whatever ``model.generate`` returns for these arguments.
    """
    prompt = instruction + input if input else instruction
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Pass the tokenizer's mask explicitly; otherwise generate() infers one from
    # pad_token_id=0 and would mask out any real token with id 0 in the prompt.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        # NOTE(review): pad id 0 is hard-coded; presumably it matches the
        # tokenizer in use — confirm.
        pad_token_id=0,
    )
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        generation_config=generation_config,
        return_dict_in_generate=True,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )

def select_different_items(list1, list2):
    """Return a random item of ``list1`` and a random item of ``list2`` at a different index.

    The global ``random`` module is used: one ``random.randint`` call, then one
    ``random.choice`` call, so results are reproducible under ``random.seed``.

    Raises:
        ValueError: if the lists differ in length, or contain fewer than two
            items (then no second, different index exists).
    """
    if len(list1) != len(list2):
        raise ValueError("Both lists must have the same length")
    if len(list1) < 2:
        # Without this check random.choice([]) raises an opaque IndexError.
        raise ValueError("Both lists must contain at least two items")

    index1 = random.randint(0, len(list1) - 1)
    # Any index except index1, so the pair never comes from the same position.
    index2 = random.choice([i for i in range(len(list2)) if i != index1])
    return list1[index1], list2[index2]

def select_same_index_items(list1, list2):
    """Pick one random index and return the items found there in both lists.

    Uses a single ``random.randint`` call on the global ``random`` module.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(list1) == len(list2):
        shared = random.randint(0, len(list1) - 1)
        return list1[shared], list2[shared]
    raise ValueError("Both lists must have the same length")

def _load_json(path):
    """Load and return the JSON document at ``path``, read as UTF-8."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _last_token_layer_states(model, tokenizer, instruction):
    """Return the last prompt token's hidden state for every layer entry except index 0.

    ``hidden_states[0]`` is taken as the prompt forward pass (first generation
    step); entry 0 inside it is skipped, presumably the embedding output —
    confirm for the model in use.
    """
    output = get_output(model=model, instruction=instruction, tokenizer=tokenizer)
    prompt_states = output['hidden_states'][0]
    # [0] selects the single batch element, [-1] the last token position.
    return [layer[0][-1] for layer in prompt_states[1:]]


def _cosine(u, v):
    """Cosine similarity of two tensors, computed in float32 NumPy."""
    # Always upcast: bfloat16 has no NumPy dtype, and float16 norms can
    # overflow (max ~65504) when large activations are squared and summed.
    a = u.detach().cpu().to(torch.float32).numpy()
    b = v.detach().cpu().to(torch.float32).numpy()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def get_r_lists_cossim(model, tokenizer, datapath1, datapath2, seed, r=500, same_index=False):
    """Sample ``r`` instruction pairs and return their layer-wise cosine similarities.

    Both files are read with ``json.load``; each entry is indexed with
    ``['instruction']``, so they are assumed to be lists of dicts with that key.
    When the two paths are equal, or ``same_index`` is False, the two items are
    drawn from different indices; otherwise they share one index.

    Args:
        model, tokenizer: forwarded to ``get_output``.
        datapath1, datapath2: JSON files supplying the first/second item.
        seed: starting seed; round k reseeds ``random`` with ``seed + k``.
        r: number of pairs to sample.
        same_index: pair items at the same index (only when paths differ).

    Returns:
        List of ``r`` lists; inner list k holds one cosine similarity per
        layer entry returned by ``_last_token_layer_states``.
    """
    sentences_1 = _load_json(datapath1)
    sentences_2 = _load_json(datapath2)

    if datapath1 != datapath2 and same_index:
        select = select_same_index_items
    else:
        select = select_different_items

    allcos = []
    for _ in range(r):
        # Reseed every round so round k's pair depends only on seed + k.
        random.seed(seed)
        seed += 1
        item1, item2 = select(sentences_1, sentences_2)

        states1 = _last_token_layer_states(model, tokenizer, item1['instruction'])
        states2 = _last_token_layer_states(model, tokenizer, item2['instruction'])
        allcos.append([_cosine(a, b) for a, b in zip(states1, states2)])

    print('end')
    return allcos

# Compute the mean and standard deviation
def compute_stats(data):
    """Return the column-wise mean and the mean +/- one standard deviation band.

    Statistics are taken along axis 0 with NumPy's default population std
    (ddof=0).

    Returns:
        (mean, mean + std, mean - std)
    """
    centre = np.mean(data, axis=0)
    spread = np.std(data, axis=0)
    return centre, centre + spread, centre - spread

def compute_differ(mean_list):
    """Return consecutive differences of ``mean_list`` with a leading NaN.

    For 1-D input, element k of the result is ``mean_list[k] - mean_list[k-1]``
    and element 0 is NaN (no predecessor), so the output has the same length as
    the input and lines up index-for-index with it.
    """
    values = np.asarray(mean_list)
    # np.insert casts the NaN to the array dtype, which fails for integer data;
    # promote non-float input to float first (float inputs keep their dtype).
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)
    differences = np.diff(values)
    return np.insert(differences, 0, np.nan)  # insert NaN at the start

def _draw_similarity_curves(ax, curves, len_layers, fill):
    """Plot each (mean, upper, lower, label, color) curve; shade upper..lower when ``fill``."""
    for mean, upper, lower, label, color in curves:
        ax.plot(mean, label=label, color=color, linewidth=1.2)
        if fill:
            ax.fill_between(np.arange(len_layers), upper, lower, where=(upper > lower), color=color, alpha=0.2)


def _style_layer_axes(ax, title, ylabel, len_layers):
    """Apply the shared title/label/tick/grid/legend styling used by every plot."""
    ax.set_title(title, fontsize="medium")
    ax.set_xlabel("Layer", fontsize="small")
    ax.set_ylabel(ylabel, fontsize="small")
    ax.tick_params(axis="both", labelsize="x-small")
    ax.set_xticks(np.arange(0, len_layers, 1))
    ax.grid(True, linewidth=0.5, linestyle="--")
    ax.legend(loc="lower left", fontsize="small")


def _save_and_close(fig, path):
    """Save ``fig`` at 500 dpi and close it so it does not stay in memory."""
    fig.savefig(path, dpi=500)
    plt.close(fig)


def plot_similarity(results, save_dir, language, fill=True):
    """Plot layer-wise cosine-similarity curves and save three PNGs in ``save_dir``.

    ``results`` must contain the keys "language-language-differ",
    "English-English-differ", "language-English-same" and
    "language-English-differ". Each value is converted with ``np.array`` and is
    assumed to be rectangular: rows = sampled pairs, columns = layers.

    Files written:
        comparison_plot.png         -- mean curves, shaded +/- 1 std when ``fill``.
        comparison_plot_no_fill.png -- the same curves without shading.
        difference_plot.png         -- same-question mean minus each
                                       different-question mean.

    The "language-English-differ" series is not plotted; only its mean is returned.

    Returns:
        Per-layer means: (language-language, English-English,
        language-English-same, language-English-differ).
    """
    language_language = np.array(results["language-language-differ"])
    English_English = np.array(results["English-English-differ"])
    language_English_same = np.array(results["language-English-same"])
    language_English_differ = np.array(results["language-English-differ"])

    mean_ll, ll_up, ll_down = compute_stats(language_language)
    mean_ee, ee_up, ee_down = compute_stats(English_English)
    mean_les, les_up, les_down = compute_stats(language_English_same)
    mean_led, _, _ = compute_stats(language_English_differ)

    len_layers = language_language.shape[1]  # columns are assumed to be layers
    curves = [
        (mean_ll, ll_up, ll_down, f"{language}-{language}-different-question", "blue"),
        (mean_ee, ee_up, ee_down, "English-English-different-question", "red"),
        (mean_les, les_up, les_down, f"English-{language}-same-question", "green"),
    ]

    # The first figure honours ``fill``; the second is always unshaded.
    for file_name, shade in (("comparison_plot.png", fill), ("comparison_plot_no_fill.png", False)):
        fig, ax = plt.subplots(figsize=(8, 6))
        _draw_similarity_curves(ax, curves, len_layers, shade)
        _style_layer_axes(ax, "Layer-wise Average Cosine Similarity", "Cosine Similarity Value", len_layers)
        _save_and_close(fig, os.path.join(save_dir, file_name))

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(mean_les - mean_ll, label=f"mean_{language}_English_same - mean_{language}_{language}", color="blue", linewidth=1.2)
    ax.plot(mean_les - mean_ee, label=f"mean_{language}_English_same - mean_English_English", color="red", linewidth=1.2)
    _style_layer_axes(ax, "Differences in Layer-wise Cosine Similarity", "Difference Value", len_layers)
    _save_and_close(fig, os.path.join(save_dir, "difference_plot.png"))

    return mean_ll, mean_ee, mean_les, mean_led



def main(
    en_path: str = 'normal.json',
    language_path: str = 'malicious.json',
    model_path: str = 'meta-llama/Llama-2-7b-hf',
    save_dir: str = 'cos_sims/',
    r: int = 500,
    language: str = 'Chinese'
):
    """Compute layer-wise cosine similarities for four pair types and save results.

    Writes ``all_cos.json`` (raw per-pair similarities), ``mean_language.json``
    (per-layer means) and the PNGs produced by ``plot_similarity`` into
    ``save_dir``. If ``save_dir`` already exists as a directory, the run is
    skipped and 0 is returned.

    Raises:
        NotADirectoryError: if ``save_dir`` exists but is not a directory
            (checked up front rather than after the expensive computation).
    """
    device_map = 'auto'

    if os.path.isdir(save_dir):
        # An existing output directory is treated as "already computed".
        print(f"{save_dir} already exists; skipping.")
        return 0
    if os.path.exists(save_dir):
        raise NotADirectoryError(f"save_dir {save_dir!r} exists and is not a directory")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right", use_fast=False)

    # Starting seeds 0/1000/2000/3000 keep the four runs' seed ranges
    # disjoint as long as r <= 1000.
    allcos_language_language_pairs = get_r_lists_cossim(model, tokenizer, language_path, language_path, 0, r, False)
    allcos_English_English_pairs = get_r_lists_cossim(model, tokenizer, en_path, en_path, 1000, r, False)
    allcos_language_English_same_pairs = get_r_lists_cossim(model, tokenizer, en_path, language_path, 2000, r, True)
    allcos_language_English_differ_pairs = get_r_lists_cossim(model, tokenizer, en_path, language_path, 3000, r, False)

    results = {
        "language-language-differ": allcos_language_language_pairs,
        "English-English-differ": allcos_English_English_pairs,
        "language-English-same": allcos_language_English_same_pairs,
        "language-English-differ": allcos_language_English_differ_pairs
    }

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'all_cos.json')
    with open(save_path, 'w') as f:
        json.dump(results, f)

    mean_language_language, mean_English_English, mean_language_English_same, mean_language_English_differ = plot_similarity(results, save_dir, language)
    mean_results = {
        "mean_language_language": mean_language_language.tolist(),
        "mean_English_English": mean_English_English.tolist(),
        "mean_language_English_same": mean_language_English_same.tolist(),
        "mean_language_English_differ": mean_language_English_differ.tolist()
    }
    save_path = os.path.join(save_dir, 'mean_language.json')
    with open(save_path, 'w') as f:
        json.dump(mean_results, f)
    

# Run the CLI only when executed as a script, so importing this module (e.g.
# to reuse get_r_lists_cossim or plot_similarity) does not start a run.
if __name__ == "__main__":
    fire.Fire(main)
